package gorgonia

import (
	"fmt"
	"io/ioutil"
	"testing"

	"github.com/pkg/errors"
	"github.com/stretchr/testify/assert"
	"gorgonia.org/tensor"
)

// testCasesSoftMaxDo is the fixture table for TestSoftmaxDo. Each case
// materializes input with the given shape, runs it through the softmax
// op, and compares against expected within a delta of 1e-7.
var testCasesSoftMaxDo = []struct {
	input    []float64    // raw backing data for the input tensor
	shape    tensor.Shape // shape the input tensor is created with
	expected []float64    // expected softmax output, same layout as input
}{
	{
		[]float64{0.2094, -1.0, 0.6411, 0.0, -0.3909}, tensor.Shape{5}, []float64{0.2382105379413429, 0.07107636737487558, 0.36681399568548617, 0.19320559786800362, 0.13069350113029174},
	},
	{
		[]float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, tensor.Shape{10}, []float64{7.801341612780742e-05, 0.00021206245143623275, 0.0005764455082375902, 0.0015669413501390804, 0.004259388198344144, 0.0115782175399118, 0.031472858344688034, 0.08555209892803112, 0.23255471590259755, 0.6321492583604866},
	},
	{
		[]float64{0.1, 0.1, 0.1}, tensor.Shape{3}, []float64{0.3333333333333333, 0.3333333333333333, 0.3333333333333333},
	},
	{
		[]float64{-0.1, 0.3, -1.1, 2.7}, tensor.Shape{4}, []float64{0.05180179352659075, 0.07727919496508177, 0.019056814854240642, 0.8518621966540868},
	},
	// Larger 2D case (32 rows × 14 columns). Each row of expected sums to
	// ~1, i.e. the softmax is applied per row.
	{
		input: []float64{
			0.0878, 0.0238, -0.034, -0.0281, 0.0148, 0.219, -0.0949, 0.0735, -0.0224, 0.00209, -0.07, 0.0604, -0.0134, 0.132,
			0.0832, 0.0226, -0.0318, -0.0277, 0.0171, 0.213, -0.0934, 0.0722, -0.0233, 0.000468, -0.0682, 0.0598, -0.014, 0.128,
			0.0839, 0.023, -0.0331, -0.0261, 0.0139, 0.207, -0.0908, 0.0701, -0.0206, 0.00072, -0.0662, 0.0579, -0.0133, 0.125,
			0.0871, 0.0234, -0.0344, -0.0289, 0.0173, 0.224, -0.0961, 0.075, -0.0247, 0.00173, -0.0712, 0.0615, -0.0151, 0.135,
			0.0881, 0.0241, -0.0334, -0.0291, 0.0174, 0.223, -0.0977, 0.0759, -0.0238, 0.000915, -0.0708, 0.0626, -0.014, 0.135,
			0.0862, 0.0241, -0.0335, -0.0275, 0.0158, 0.216, -0.0938, 0.0727, -0.0227, 0.00159, -0.0699, 0.06, -0.0147, 0.131,
			0.0896, 0.0243, -0.0342, -0.0288, 0.0164, 0.224, -0.0973, 0.0755, -0.0232, 0.00216, -0.071, 0.0622, -0.0135, 0.136,
			0.0926, 0.0253, -0.0356, -0.0297, 0.0174, 0.234, -0.101, 0.079, -0.0244, 0.00166, -0.0744, 0.065, -0.0147, 0.142,
			0.0917, 0.0251, -0.0347, -0.0291, 0.0166, 0.229, -0.1, 0.0779, -0.0233, 0.00135, -0.0731, 0.0643, -0.0139, 0.139,
			0.0908, 0.0249, -0.0345, -0.0296, 0.0184, 0.232, -0.1, 0.0783, -0.0249, 0.0015, -0.074, 0.0644, -0.0148, 0.14,
			0.0866, 0.0232, -0.0339, -0.0267, 0.0149, 0.216, -0.094, 0.073, -0.0213, 0.00114, -0.0682, 0.06, -0.0128, 0.131,
			0.0836, 0.0232, -0.0329, -0.0279, 0.017, 0.215, -0.0929, 0.0719, -0.0238, 0.00139, -0.0692, 0.0597, -0.0151, 0.13,
			0.0921, 0.025, -0.0353, -0.0295, 0.0174, 0.233, -0.101, 0.0788, -0.0243, 0.00159, -0.0741, 0.0647, -0.0144, 0.141,
			0.086, 0.0232, -0.0334, -0.027, 0.0149, 0.214, -0.0927, 0.0722, -0.0216, 0.00198, -0.068, 0.0589, -0.0133, 0.13,
			0.0893, 0.0239, -0.0332, -0.0287, 0.0172, 0.224, -0.0978, 0.0763, -0.0234, 0.00142, -0.0714, 0.0627, -0.0132, 0.136,
			0.0891, 0.0242, -0.0348, -0.029, 0.0171, 0.226, -0.0981, 0.0766, -0.0241, 0.00134, -0.0721, 0.0628, -0.0148, 0.137,
			0.0869, 0.0231, -0.0339, -0.0275, 0.0155, 0.218, -0.0945, 0.0738, -0.022, 0.00168, -0.0691, 0.0598, -0.0133, 0.132,
			0.084, 0.023, -0.0311, -0.0275, 0.0178, 0.215, -0.0941, 0.0726, -0.0237, 0.000525, -0.069, 0.0602, -0.0134, 0.131,
			0.0873, 0.0231, -0.0336, -0.0285, 0.0166, 0.221, -0.0958, 0.0746, -0.0239, 0.00139, -0.0703, 0.0615, -0.0137, 0.134,
			0.0861, 0.0231, -0.0325, -0.0283, 0.0174, 0.219, -0.0959, 0.0741, -0.0238, 0.00106, -0.0695, 0.0611, -0.0138, 0.133,
			0.0843, 0.0228, -0.0322, -0.0272, 0.0141, 0.21, -0.0917, 0.0712, -0.0214, 0.000737, -0.0674, 0.0589, -0.0131, 0.127,
			0.0857, 0.0228, -0.0316, -0.0282, 0.0156, 0.214, -0.0939, 0.0729, -0.0225, 0.000756, -0.0681, 0.0602, -0.0128, 0.129,
			0.0894, 0.0244, -0.0353, -0.0283, 0.0164, 0.226, -0.098, 0.0765, -0.0232, 0.00113, -0.0721, 0.0628, -0.0147, 0.137,
			0.0792, 0.0213, -0.0316, -0.0246, 0.0139, 0.2, -0.0862, 0.0661, -0.021, 0.00137, -0.0644, 0.0545, -0.0142, 0.121,
			0.0847, 0.0231, -0.0323, -0.0266, 0.0159, 0.213, -0.0924, 0.072, -0.0218, 0.00133, -0.0685, 0.0588, -0.0133, 0.129,
			0.0842, 0.0223, -0.0332, -0.0266, 0.014, 0.21, -0.0914, 0.0708, -0.0216, 0.000902, -0.0666, 0.0587, -0.0133, 0.126,
			0.0914, 0.0248, -0.0355, -0.029, 0.0169, 0.231, -0.0997, 0.078, -0.0238, 0.00171, -0.0735, 0.0637, -0.0144, 0.14,
			0.0851, 0.023, -0.0327, -0.028, 0.0175, 0.22, -0.0947, 0.0736, -0.0243, 0.00106, -0.0704, 0.0605, -0.0146, 0.133,
			0.0856, 0.0231, -0.0332, -0.0271, 0.0166, 0.217, -0.0941, 0.0734, -0.0227, 0.00107, -0.069, 0.0601, -0.0139, 0.132,
			0.0825, 0.0223, -0.0329, -0.026, 0.0151, 0.209, -0.0896, 0.0692, -0.0221, 0.00171, -0.0669, 0.0568, -0.0145, 0.126,
			0.0856, 0.0232, -0.0322, -0.0287, 0.0172, 0.218, -0.095, 0.074, -0.024, 0.00106, -0.07, 0.0606, -0.0141, 0.132,
			0.0819, 0.0224, -0.0316, -0.0255, 0.015, 0.205, -0.0903, 0.0691, -0.0205, 0.000611, -0.0651, 0.0577, -0.0123, 0.124,
		},
		shape: tensor.Shape{32, 14},
		expected: []float64{
			0.07580373195832904, 0.07110427955999997, 0.06711097052129973, 0.06750809561440703, 0.07046721214751539, 0.08643109662652378, 0.06314587498214086, 0.07472745233129745, 0.06789399051506614, 0.06957724162455223, 0.06473794629960379, 0.07375490678724296, 0.06850779440403143, 0.07922940662799016,
			0.07553698070492218, 0.07109537937458059, 0.0673311072133112, 0.06760773144505652, 0.07070542813692551, 0.08600645460624466, 0.06330867382399817, 0.07471062719385999, 0.06790586086716144, 0.0695391808527201, 0.06492432409851422, 0.07378993550224344, 0.06854033108680987, 0.07899798509365204,
			0.07563735949020686, 0.07116850209805829, 0.06728587505246605, 0.06775852853502212, 0.0705238065426734, 0.0855456656279284, 0.06351336360776534, 0.07460073310252557, 0.06813222717118268, 0.06960040137967038, 0.06509516878658524, 0.0736961334366599, 0.06863141223823133, 0.07881082293102451,
			0.07571828316350099, 0.07104543859810176, 0.06705543414453265, 0.0674252551077198, 0.07061338053947136, 0.08682717533785694, 0.06304316513167385, 0.07480761260501406, 0.06770903670336426, 0.06952244517535047, 0.06463264686736289, 0.07380449610600028, 0.06836217349623949, 0.07943345702381124,
			0.07577489761197814, 0.07107723279896462, 0.06710557139359428, 0.06739474663177814, 0.07060260711076131, 0.08671848536522016, 0.06292648068737672, 0.0748560601661504, 0.0677528870226157, 0.06944826394751061, 0.06464217565864186, 0.07386706595599397, 0.06842012946324129, 0.0794133961861729,
			0.07571667561503748, 0.07115769153033254, 0.06717481642562362, 0.0675790768927982, 0.07056952695018931, 0.0862110553459827, 0.06324388364833107, 0.07470136923223723, 0.06790423621996317, 0.06957382518787383, 0.06477362000090751, 0.07375866071284001, 0.06844964885138428, 0.07918591338649904,
			0.07586555623009507, 0.07106982018973229, 0.0670315079520724, 0.06739445717595188, 0.07051058051545611, 0.08677883684519459, 0.06293250335813079, 0.07480335798280527, 0.06777292485658129, 0.06951362501868301, 0.06460958515479787, 0.07381506007103907, 0.06843352093907455, 0.07946866371038583,
			0.07598825125930292, 0.07104253197884605, 0.066845149278942, 0.06724070139098821, 0.07048350703214519, 0.08752975035591505, 0.06261336411851758, 0.07496180668615737, 0.06759802317465682, 0.06938278206204422, 0.06430122868181473, 0.07391965348672717, 0.06825691445591649, 0.07983633603802633,
			0.0759650476816685, 0.07107057060457, 0.06694513046425878, 0.06732107485668994, 0.07046903091985701, 0.08714499456226771, 0.06271328678119782, 0.07492393025637709, 0.06771267162370743, 0.06940253092966628, 0.06442316900544866, 0.07391186246516075, 0.06835217167838212, 0.07964452817074806,
			0.07588261555515456, 0.07104316244539605, 0.06694608634866177, 0.06727492717383647, 0.07058287942988004, 0.08739059027411504, 0.06270164065844434, 0.07493998656568233, 0.06759186354860738, 0.06940005181302644, 0.06435326134422872, 0.07390552690268615, 0.0682780005294755, 0.07970940741080541,
			0.07572849056812758, 0.07107633577009935, 0.0671315717478957, 0.06761666329846892, 0.07048884364820387, 0.08619002495303499, 0.06321581145869905, 0.0747055548262468, 0.06798278090816147, 0.06952555975158711, 0.06486800099339428, 0.07374066796687649, 0.06856309739697362, 0.07916659671223078,
			0.07555205748210254, 0.07112379296811323, 0.06724360504758324, 0.06758066502054558, 0.07068418963025343, 0.08616136899703665, 0.06332764234416066, 0.07467324946146231, 0.0678583145397035, 0.06958937667642051, 0.06484643405936474, 0.07376777047080536, 0.06845125743782313, 0.07914027586462503,
			0.07596459029290821, 0.07103461649660746, 0.06687781611862521, 0.06726683451491486, 0.07049679970373515, 0.08745875532447635, 0.0626251725389753, 0.0749609502427929, 0.06761753308042232, 0.06939100960430751, 0.06433265232214526, 0.07391141743871536, 0.06829027121711537, 0.07977158110425897,
			0.0757076953146851, 0.07109946489485341, 0.06718700230148886, 0.06761837804617317, 0.07051178159570559, 0.08604580843717952, 0.06331864342633538, 0.07467010495930128, 0.0679845049405494, 0.06960662925274143, 0.06490208896853973, 0.07368356757926006, 0.06855112455007047, 0.07911320573311643,
			0.07583718024893207, 0.07103613398831786, 0.06709360118049353, 0.06739620272864916, 0.0705617827417398, 0.08677240678919816, 0.0628963841896962, 0.07485767746839475, 0.0677543508572872, 0.06945705701975846, 0.06457896102358412, 0.07384650461571897, 0.06844898183156023, 0.07946277531666916,
			0.07582745683861136, 0.07106254853366956, 0.06699114553624523, 0.06738082315304855, 0.07055979133911858, 0.08695236626183257, 0.06288203145680492, 0.07488551304172053, 0.06771179941611492, 0.06945648591065866, 0.06453840380731399, 0.07385919087239427, 0.0683444564460048, 0.07954798738646204,
			0.07573860930571823, 0.07105740424606419, 0.06712040263541186, 0.06755135077535768, 0.07051941492274218, 0.08634820879904852, 0.06317369908893183, 0.0747529039899984, 0.06792390679452369, 0.0695515400287077, 0.06479886325742716, 0.07371365505105491, 0.0685174228348136, 0.0792326182702,
			0.07555287707789624, 0.07108190258682497, 0.06733854251103616, 0.06758139814192747, 0.07071323605707901, 0.0861278456542269, 0.06322708501847565, 0.07469646510239021, 0.06783869601120339, 0.06950215572849829, 0.06483416938890552, 0.07377594793641269, 0.06854104547049059, 0.07918863331463287,
			0.07574410047570453, 0.07103413668305468, 0.06711855678038553, 0.06746173578258094, 0.07057391214473711, 0.08657928293302596, 0.06307096744900711, 0.0747882330057183, 0.06777277460801437, 0.06950860513842702, 0.06469995848466632, 0.07381489642741564, 0.06846759446621521, 0.07936524562104719,
			0.0756717060516291, 0.07105145454009963, 0.06720880907724458, 0.06749167968783155, 0.07064761329018292, 0.08642736294272377, 0.06308003559618351, 0.07476907221361767, 0.06779607662586773, 0.06950261142997739, 0.06476752539297144, 0.07380336497224824, 0.06847743852360876, 0.07930524965581359,
			0.07563645146680772, 0.07112495994208998, 0.06731871823397463, 0.06765615471335058, 0.07050885672565142, 0.08576734410785562, 0.06343008839586677, 0.07465207569127591, 0.06804970059049206, 0.06957291430008673, 0.06499031957843418, 0.07373947913462867, 0.06861686357579774, 0.07893607354368812,
			0.07569158658182383, 0.07107722860166533, 0.06731391746556793, 0.0675431743007199, 0.07056731046387382, 0.08605331213010359, 0.06324822175577374, 0.07472890855666557, 0.06792926972082854, 0.0695275455347745, 0.06490125835629644, 0.07378585249938487, 0.06859138972758888, 0.0790410243049331,
			0.07584616545178219, 0.07107297386680439, 0.06695408929827328, 0.06742441213279868, 0.07050665835825033, 0.08694773144811387, 0.0628849678442258, 0.0748740336485852, 0.06776915498171088, 0.06943820012494253, 0.06453496372290675, 0.07385525396295956, 0.06834764791099958, 0.07954374724764714,
			0.07538565512371828, 0.07114478361763509, 0.0674790383417934, 0.0679530487109372, 0.07062025536698704, 0.08506511368360894, 0.06389345986184054, 0.07440454335436263, 0.06819812055093102, 0.06974090420713536, 0.06530163057984227, 0.07354643728883982, 0.06866344809125371, 0.07860356122111473,
			0.07562397749379275, 0.07110611903737409, 0.0672739706266011, 0.0676585272042326, 0.0705959936354916, 0.08597644776224045, 0.06334990426233511, 0.0746696259392387, 0.06798406880962622, 0.06957486697999139, 0.0648822050302828, 0.07369046356590116, 0.06856439629225838, 0.07904943336063361,
			0.07563986730000552, 0.07109972645158245, 0.06726119610287586, 0.06770658817421822, 0.07051204099049824, 0.08577979501497728, 0.06345833123151746, 0.07463305379391726, 0.06804596885976037, 0.06959449639477064, 0.06505177487882444, 0.07373543538604437, 0.06861310074280513, 0.07886862467820287,
			0.07593485247268887, 0.07104232089240368, 0.06688506966894726, 0.06732123863525878, 0.07048329760671607, 0.08731093985767327, 0.06272598329217195, 0.0749241125311565, 0.06767222083901234, 0.06942074681995235, 0.064391122122705, 0.07386032195241513, 0.06831133886358241, 0.07971643444531631,
			0.07560629485457866, 0.07105395691652636, 0.06720445533714883, 0.06752105971470501, 0.07066423287702295, 0.08652553260367433, 0.06316431748750224, 0.07474180282034944, 0.06777135038785487, 0.06951201011082883, 0.06471801132999752, 0.07376907051087278, 0.06843193112369053, 0.0793159739252478,
			0.07565081367829454, 0.07106736258185826, 0.06717681653743945, 0.0675878474881799, 0.07060692277557727, 0.08627399291946071, 0.06320783063962776, 0.07473348085957007, 0.06788588922811524, 0.06951886787922641, 0.06481442560885718, 0.07374610616056912, 0.0684859213223351, 0.07924372232088901,
			0.07553407196386892, 0.07112108441337289, 0.06730158843979399, 0.06776757520556947, 0.07061085164775649, 0.08571980011329053, 0.06359175596462249, 0.0745361198985853, 0.06803238479191852, 0.0696716741696257, 0.06505179760274592, 0.07361757873659114, 0.06855140067851208, 0.0788923163737467,
			0.07565440995291926, 0.07107784839868066, 0.06724722358791015, 0.06748300124067051, 0.0706526581545908, 0.08636441544447576, 0.0631539712428355, 0.07478188920166674, 0.06780091786533832, 0.06952147743627629, 0.06475272163644419, 0.07378649591573709, 0.06847548052790334, 0.07924748939455148,
			0.07549697252896907, 0.07113593019535562, 0.06739646419295066, 0.06780883908925636, 0.07061146721822258, 0.08538688832362386, 0.06355416640221782, 0.07453676968870958, 0.06814873230964302, 0.06960271369819923, 0.06517608169742432, 0.07369187556096998, 0.0687098493503266, 0.07874324974413119},
	},
}

// TestSoftMaxFull exercises SoftMax and LogSoftMax inside a complete
// graph: it computes x·wᵀ, applies the (log-)softmax, reduces with Mean,
// takes gradients, runs the tape machine, and performs one Adam step.
// The softmax output is then checked against Expected (values within
// 1e-5, and shape), for both float64 and float32.
func TestSoftMaxFull(t *testing.T) {
	testCases := []struct {
		Dtype    tensor.Dtype
		XInit    InitWFn      // initializer for the x tensor
		XShape   tensor.Shape // shape used for both x and w
		Expected tensor.Tensor
		IsLog    bool // when true, LogSoftMax is used instead of SoftMax
	}{
		{
			Dtype:    tensor.Float64,
			XInit:    RangedFromWithStep(0.0, 0.01),
			XShape:   tensor.Shape{2, 3},
			IsLog:    false,
			Expected: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{0.49932500041006217, 0.5006749995899378, 0.49730002624369385, 0.502699973756306})),
		},
		{
			Dtype:    tensor.Float32,
			XInit:    RangedFromWithStep(0.0, 0.01),
			XShape:   tensor.Shape{2, 3},
			IsLog:    false,
			Expected: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float32{0.49932498, 0.50067496, 0.49730003, 0.5027})),
		},
		{
			Dtype:    tensor.Float64,
			XInit:    RangedFromWithStep(0.0, 0.01),
			XShape:   tensor.Shape{2, 3},
			IsLog:    true,
			Expected: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{-0.6944980918096686, -0.6917980918096686, -0.6985617604890871, -0.6877617604890871})),
		},
		{
			Dtype:    tensor.Float32,
			XInit:    RangedFromWithStep(0.0, 0.01),
			XShape:   tensor.Shape{2, 3},
			IsLog:    true,
			Expected: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float32{-0.6944981, -0.69179815, -0.6985617, -0.6877617})),
		},
	}
	for i, tC := range testCases {
		t.Run(fmt.Sprintf("#%d %v", i+1, tC.XShape), func(t *testing.T) {
			c := assert.New(t)

			g := NewGraph()

			// x is the input, w the weights; both share XShape, so
			// x·wᵀ below yields a (2,2) result matching Expected.
			x := NewTensor(g, tC.Dtype, 2, WithShape(tC.XShape...), WithInit(tC.XInit), WithName("x"))
			w := NewTensor(g, tC.Dtype, 2, WithShape(tC.XShape...), WithInit(RangedFromWithStep(-0.05, 0.03)), WithName("w"))

			t.Logf("Input: %v", x.Value())

			optim := NewAdamSolver(WithLearnRate(0.1))

			wT := Must(Transpose(w, 1, 0))

			// Fully-connected output before the softmax.
			output := Must(Mul(x, wT))

			// Capture the pre-softmax value so it can be logged below
			// (Read must be attached before output is reassigned).
			var fcVal Value
			Read(output, &fcVal)

			softMaxFn := SoftMax
			if tC.IsLog {
				softMaxFn = LogSoftMax
			}

			output = Must(softMaxFn(output))

			cost := Must(Mean(output))

			_, err := Grad(cost, x, w)
			c.NoError(err)

			vm := NewTapeMachine(g, BindDualValues(w))
			c.NoError(vm.RunAll())

			t.Logf("dx: %v", x.Deriv().Value())

			// One optimizer step on w, exercising the solver path.
			c.NoError(optim.Step(NodesToValueGrads(Nodes{w})))

			t.Logf("wT: %v", wT.Value())
			t.Logf("output: %v", output.Value())
			t.Logf("FC Val: %v", fcVal)
			t.Logf("cost: %v", cost.Value())
			t.Logf("w: %v", w.Value())

			c.InDeltaSlice(tC.Expected.Data(), output.Value().Data(), 1e-5, "got=%#v\nexpected=%#v", output.Value().Data(), tC.Expected.Data())
			c.Equal(tC.Expected.Shape(), output.Shape())
		})
	}
}

// TestSoftmaxDo runs every fixture in testCasesSoftMaxDo through the raw
// softmax op's Do method and checks the output against the expected
// values within 1e-7.
func TestSoftmaxDo(t *testing.T) {
	c := assert.New(t)

	for caseIdx, tc := range testCasesSoftMaxDo {
		// Build the input tensor exactly as described by the fixture.
		input := tensor.New(tensor.Of(tensor.Float64), tensor.WithShape(tc.shape...), tensor.WithBacking(tc.input))
		op := newSoftmaxOp(input.Shape())

		result, err := op.Do(input)

		t.Logf("out: %v", result)

		c.NoError(err, "failed test case: %d", caseIdx)
		c.InDeltaSlice(tc.expected, result.Data().([]float64), 1e-7)
	}
}

// func TestSoftmaxKernel(t *testing.T) {
// 	// this test is used for migrating to a new algorithm for softmax
// 	assert := assert.New(t)
// 	a := tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]float64{-0.1, 0.3, -1.1, 2.7, 3.14, 0.1}))
// 	op := newSoftmaxOp(a.Shape())
// 	op.axis = 0
// 	b0, _ := op.Do(a)
// 	op.axis = 1
// 	b1, _ := op.Do(a)

// 	// across axis 0
// 	out := make([]float64, 6)
// 	op.do(tensor.Shape{2, 3}, 0, a.Data().([]float64), out)
// 	assert.True(floatsEqual64(out, b0.Data().([]float64)))
// 	t.Logf("\n%v\n%v", out, b0.Data())

// 	// across axis 1
// 	out = make([]float64, 6)
// 	op.do(tensor.Shape{2, 3}, 1, a.Data().([]float64), out)
// 	assert.True(floatsEqual64(out, b1.Data().([]float64)))
// 	/*
// 		// super large
// 		a = tensor.New(tensor.WithShape(10, 1024, 2048, 30), tensor.WithBacking(Uniform64(-1, 1, 10, 1024, 2048, 30)))
// 		op = newSoftmaxOp(a.Shape())
// 		op.axis = 0
// 		b, _ := op.Do(a)

// 		out = make([]float64, 10*1024*2048*30)
// 		op.doF64s(tensor.Shape{10, 1024, 2048, 30}, 0, a.Data().([]float64), out)
// 		assert.True(floatsEqual64(out, b.Data().([]float64)))
// 	*/
// }

// oldsoftmax is the legacy softmax implementation, retained as a
// reference to compare against the current SoftMax op. It exponentiates
// a, sums along the chosen axis, and divides (with broadcasting when the
// sum is not a scalar).
func oldsoftmax(a *Node, axes ...int) (retVal *Node, err error) {
	shp := a.Shape()

	// Default to the last dimension; column vectors and non-row vectors
	// are reduced along axis 0.
	axis := shp.Dims() - 1
	if a.IsColVec() || (a.IsVector() && !a.IsRowVec()) {
		axis = 0
	}

	// An explicit axis overrides the default, but must be in range.
	if len(axes) > 0 {
		if axes[0] >= axis+1 || axes[0] < 0 {
			return nil, errors.Errorf("Cannot perform SoftMax on axis %d. Input has shape %v", axes[0], a.Shape())
		}
		axis = axes[0]
	}

	exp, err := Exp(a)
	if err != nil {
		return nil, errors.Wrap(err, operationError)
	}
	sum, err := Sum(exp, axis)
	if err != nil {
		return nil, errors.Wrap(err, operationError)
	}

	// Scalar sum: a plain elementwise division suffices.
	if sum.IsScalar() {
		return HadamardDiv(exp, sum)
	}

	// Otherwise reinsert a singleton dimension at the reduced axis so
	// the sum broadcasts against exp.
	// TODO: multirank softmax
	ss := sum.Shape()
	if diff := exp.Shape().Dims() - ss.Dims(); diff > 0 {
		newShape := tensor.Shape(tensor.BorrowInts(ss.Dims() + diff))
		copy(newShape, ss)
		copy(newShape[axis+1:], newShape[axis:])
		newShape[axis] = 1

		if sum, err = Reshape(sum, newShape); err != nil {
			return nil, errors.Wrap(err, "Failed to reshape")
		}
	}

	return BroadcastHadamardDiv(exp, sum, nil, []byte{byte(axis)})
}

// TestOld_NewSoftmax runs the current SoftMax and the legacy oldsoftmax
// on identical inputs in two separate graphs, differentiates both, and
// logs the values and gradients side by side for manual comparison. It
// also dumps both graphs as .dot files for visual inspection.
func TestOld_NewSoftmax(t *testing.T) {
	a := tensor.New(tensor.WithBacking([]float64{0.1, 0.1, 0.3, 0.1, 0.4}))

	// New implementation.
	g := NewGraph()
	A := NodeFromAny(g, a, WithName("A"))
	sm := Must(SoftMax(A))
	sum := Must(Sum(sm))
	if _, err := Grad(sum, A); err != nil {
		t.Fatal(err)
	}

	// Old implementation, on an identical input in its own graph.
	h := NewGraph()
	A2 := NodeFromAny(h, a, WithName("A"))
	sm2 := Must(oldsoftmax(A2))
	sum2 := Must(Sum(sm2))
	if _, err := Grad(sum2, A2); err != nil {
		t.Fatal(err)
	}

	m1 := NewTapeMachine(g, TraceExec(), BindDualValues())
	defer m1.Close()
	if err := m1.RunAll(); err != nil {
		t.Fatalf("m1 %v", err)
	}

	m2 := NewTapeMachine(h, TraceExec(), BindDualValues())
	defer m2.Close()
	if err := m2.RunAll(); err != nil {
		t.Fatalf("m2 %v", err)
	}

	Agrad, err := A.Grad()
	if err != nil {
		t.Fatalf("No grad for A %v", err)
	}

	A2grad, err := A2.Grad()
	if err != nil {
		t.Fatalf("No grad for A2 %v", err)
	}

	t.Logf("\n%v\n%v", sm.Value(), sm2.Value())
	t.Logf("\n%v\n%v", Agrad, A2grad)

	// Best-effort debug artifacts: a failed write should be surfaced in
	// the log, but must not fail the test.
	if err := ioutil.WriteFile("oldsm.dot", []byte(h.ToDot()), 0644); err != nil {
		t.Logf("failed to write oldsm.dot: %v", err)
	}
	if err := ioutil.WriteFile("newsm.dot", []byte(g.ToDot()), 0644); err != nil {
		t.Logf("failed to write newsm.dot: %v", err)
	}
}

// BenchmarkSoftmaxLargeOldAxis0 measures the softmax op's Do method on a
// large 4D tensor (10×1024×2048×30) reduced along axis 0.
func BenchmarkSoftmaxLargeOldAxis0(b *testing.B) {
	b.StopTimer()

	// Expensive setup: allocate and fill the input outside the timed region.
	input := tensor.New(tensor.WithShape(10, 1024, 2048, 30), tensor.WithBacking(Uniform64(-1, 1, 10, 1024, 2048, 30)))
	op := newSoftmaxOp(input.Shape())
	op.axis = 0

	// Sink for the result so the call cannot be optimized away.
	var result Value

	b.ResetTimer()
	b.StartTimer()
	for n := 0; n < b.N; n++ {
		result, _ = op.Do(input)
	}
	_ = result
}

// func BenchmarkSoftmaxLargeNewAxis0(b *testing.B) {
// 	b.StopTimer()
// 	a := tensor.New(tensor.WithShape(10, 1024, 2048, 30), tensor.WithBacking(Uniform64(-1, 1, 10, 1024, 2048, 30)))
// 	op := newSoftmaxOp(a.Shape())
// 	op.axis = 0
// 	out := make([]float64, len(a.Data().([]float64)))

// 	b.ResetTimer()
// 	b.StartTimer()
// 	for i := 0; i < b.N; i++ {
// 		op.do(a.Shape(), 0, a.Data().([]float64), out)
// 	}

// }

// func BenchmarkSoftmaxMedOldAxis0(b *testing.B) {
// 	b.StopTimer()
// 	a := tensor.New(tensor.WithShape(1200, 2500), tensor.WithBacking(Uniform64(-1, 1, 1200, 2500)))
// 	op := newSoftmaxOp(a.Shape())
// 	op.axis = 0
// 	var v Value

// 	b.ResetTimer()
// 	b.StartTimer()
// 	for i := 0; i < b.N; i++ {
// 		v, _ = op.Do(a)
// 	}
// 	_ = v
// }

// func BenchmarkSoftmaxMedNewAxis0(b *testing.B) {
// 	b.StopTimer()
// 	a := tensor.New(tensor.WithShape(1200, 2500), tensor.WithBacking(Uniform64(-1, 1, 1200, 2500)))
// 	op := newSoftmaxOp(a.Shape())
// 	op.axis = 0
// 	out := make([]float64, len(a.Data().([]float64)))

// 	b.ResetTimer()
// 	b.StartTimer()
// 	for i := 0; i < b.N; i++ {
// 		op.do(a.Shape(), 0, a.Data().([]float64), out)
// 	}

// }
